{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# default_exp data.mixed_augmentation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Label-mixing transforms\n",
    "\n",
    "> Callbacks that perform data augmentation by mixing samples in different ways."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from tsai.imports import *\n",
    "from tsai.utils import *\n",
    "from tsai.data.external import *\n",
    "from tsai.data.core import *\n",
    "from tsai.data.preprocessing import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from torch.distributions.beta import Beta"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def _reduce_loss(loss, reduction='mean'):\n",
    "    \"\"\"Collapse an unreduced `loss` according to `reduction`.\n",
    "\n",
    "    `'mean'` averages every element, `'sum'` adds them up, and any other\n",
    "    value hands `loss` back untouched.\n",
    "    \"\"\"\n",
    "    if reduction == 'mean':\n",
    "        return loss.mean()\n",
    "    if reduction == 'sum':\n",
    "        return loss.sum()\n",
    "    return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class MixHandler1d(Callback):\n",
    "    \"A handler class for implementing mixed sample data augmentation\"\n",
    "    run_valid = False\n",
    "\n",
    "    def __init__(self, alpha=0.5):\n",
    "        # Mixing coefficients are drawn from a symmetric Beta(alpha, alpha) distribution.\n",
    "        self.distrib = Beta(alpha, alpha)\n",
    "\n",
    "    def before_train(self):\n",
    "        \"Record whether targets are available and, for losses flagged `y_int`, swap in the mixed loss `lf`\"\n",
    "        self.labeled = bool(self.dls.d != 0)\n",
    "        if self.labeled:\n",
    "            self.stack_y = getattr(self.learn.loss_func, 'y_int', False)\n",
    "            if self.stack_y: self.old_lf, self.learn.loss_func = self.learn.loss_func, self.lf\n",
    "\n",
    "    def after_train(self):\n",
    "        \"Put the original loss function back on the learner\"\n",
    "        # getattr guards: the cancel hooks below may fire before `before_train` has ever set these flags.\n",
    "        if getattr(self, 'labeled', False) and getattr(self, 'stack_y', False):\n",
    "            self.learn.loss_func = self.old_lf\n",
    "\n",
    "    def after_cancel_train(self):\n",
    "        \"If the training phase is cancelled, still put the original loss function back\"\n",
    "        self.after_train()\n",
    "\n",
    "    def after_cancel_fit(self):\n",
    "        \"If the whole fit is cancelled, still put the original loss function back\"\n",
    "        self.after_train()\n",
    "\n",
    "    def lf(self, pred, *yb):\n",
    "        \"Loss used while `stack_y` is set: interpolate the losses against the shuffled and the original targets\"\n",
    "        if not self.training: return self.old_lf(pred, *yb)\n",
    "        # `self.yb1` (shuffled targets) and `self.lam` (mixing weight) are set by the subclass in `before_batch`.\n",
    "        with NoneReduce(self.old_lf) as lf: loss = torch.lerp(lf(pred, *self.yb1), lf(pred, *yb), self.lam)\n",
    "        # Apply the reduction the original loss declares, defaulting to 'mean'.\n",
    "        return _reduce_loss(loss, getattr(self.old_lf, 'reduction', 'mean'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class MixUp1d(MixHandler1d):\n",
    "    \"Implementation of https://arxiv.org/abs/1710.09412\"\n",
    "\n",
    "    def __init__(self, alpha=.4):\n",
    "        super().__init__(alpha)\n",
    "\n",
    "    def before_batch(self):\n",
    "        \"Blend every sample (and its target, when present) with a randomly chosen partner from the same batch\"\n",
    "        bs = self.x.size(0)\n",
    "        lam = self.distrib.sample((bs, ))\n",
    "        # Keep the larger share (>= 0.5) as the weight of the original sample.\n",
    "        self.lam = torch.max(lam, 1 - lam).to(self.x.device)\n",
    "        shuffle = torch.randperm(bs)\n",
    "        xb1 = self.x[shuffle]\n",
    "        # Interpolate tensor-to-tensor. `L(xb1, self.xb).map_zip(...)` would zip the rows of `xb1` with the\n",
    "        # `xb` tuple, so the whole batch would be blended with the first shuffled sample only.\n",
    "        self.learn.xb = (torch.lerp(xb1, self.x, unsqueeze(self.lam, n=self.x.ndim - 1)), )\n",
    "        if self.labeled:\n",
    "            self.yb1 = tuple((self.y[shuffle], ))\n",
    "            if not self.stack_y:\n",
    "                # Targets are mixed directly only when the loss is not handled by the stacked-loss path (`lf`).\n",
    "                self.learn.yb = tuple(L(self.yb1, self.yb).map_zip(torch.lerp, weight=unsqueeze(self.lam, n=self.y.ndim - 1)))\n",
    "\n",
    "MixUp1D = MixUp1d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Demo-only imports: this cell carries no `#export` flag\n",
    "from tsai.models.utils import *\n",
    "from tsai.models.ResNet import *\n",
    "from tsai.learner import *\n",
    "# Load the 'NATOPS' dataset through get_UCR_data, unsplit (return_split=False), together with the split indices\n",
    "dsid = 'NATOPS'\n",
    "X, y, splits = get_UCR_data(dsid, return_split=False)\n",
    "# Two transform entries: None and Categorize() (presumably for inputs and targets respectively - confirm)\n",
    "tfms = [None, Categorize()]\n",
    "batch_tfms = TSStandardize()\n",
    "dls = get_ts_dls(X, y, tfms=tfms, splits=splits, batch_tfms=batch_tfms)\n",
    "model = build_model(ResNet, dls=dls)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>1.814203</td>\n",
       "      <td>1.782917</td>\n",
       "      <td>00:03</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Smoke test: fit for one cycle with MixUp1d(alpha=0.4) attached as a callback\n",
    "learn = Learner(dls, model, cbs=MixUp1d(0.4))\n",
    "learn.fit_one_cycle(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class CutMix1d(MixHandler1d):\n",
    "    \"Implementation of `https://arxiv.org/abs/1905.04899`\"\n",
    "\n",
    "    def __init__(self, alpha=1.):\n",
    "        super().__init__(alpha)\n",
    "\n",
    "    def before_batch(self):\n",
    "        \"Overwrite one window along the last axis of every sample with the same window from a shuffled partner\"\n",
    "        bs, *_, seq_len = self.x.size()\n",
    "        self.lam = self.distrib.sample((1, ))\n",
    "        shuffle = torch.randperm(bs)\n",
    "        xb1 = self.x[shuffle]\n",
    "        # rand_bbox returns one-element tensors; plain ints make the slicing and the arithmetic below explicit.\n",
    "        x1, x2 = (int(v) for v in self.rand_bbox(seq_len, self.lam))\n",
    "        self.learn.xb[0][..., x1:x2] = xb1[..., x1:x2]\n",
    "        # Replace the sampled lam by the fraction of the sequence that was actually kept after clamping.\n",
    "        self.lam = 1 - (x2 - x1) / float(seq_len)\n",
    "        if self.labeled:\n",
    "            self.yb1 = tuple((self.y[shuffle], ))\n",
    "            if not self.stack_y:\n",
    "                # lam is a plain float here: it broadcasts over targets of any rank, so it must not be\n",
    "                # passed through `unsqueeze` (which needs a tensor and would fail for targets with ndim > 1).\n",
    "                self.learn.yb = tuple(L(self.yb1, self.yb).map_zip(torch.lerp, weight=self.lam))\n",
    "\n",
    "    def rand_bbox(self, seq_len, lam):\n",
    "        \"Return the clamped start and end (one-element tensors) of a random window on a sequence of length `seq_len`\"\n",
    "        # NOTE(review): the square root looks inherited from the 2D formulation, where each side of the box\n",
    "        # scales with sqrt(1 - lam); for a 1D window the length may be meant to scale with (1 - lam) - confirm.\n",
    "        cut_rat = torch.sqrt(1. - lam)\n",
    "        cut_seq_len = torch.round(seq_len * cut_rat).type(torch.long)\n",
    "\n",
    "        # uniform\n",
    "        cx = torch.randint(0, seq_len, (1, ))\n",
    "        # The window is centred on cx and clipped to [0, seq_len], so it can be shorter than cut_seq_len.\n",
    "        x1 = torch.clamp(cx - cut_seq_len // 2, 0, seq_len)\n",
    "        x2 = torch.clamp(cx + cut_seq_len // 2, 0, seq_len)\n",
    "        return x1, x2\n",
    "    \n",
    "CutMix1D = CutMix1d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>1.795043</td>\n",
       "      <td>1.738316</td>\n",
       "      <td>00:03</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Rebuild the labeled data and a fresh model, then fit for one cycle with CutMix1d(alpha=1.) as a callback\n",
    "dsid = 'NATOPS'\n",
    "X, y, splits = get_UCR_data(dsid, return_split=False)\n",
    "tfms = [None, Categorize()]\n",
    "batch_tfms = TSStandardize()\n",
    "dls = get_ts_dls(X, y, tfms=tfms, splits=splits, batch_tfms=batch_tfms)\n",
    "model = build_model(ResNet, dls=dls)\n",
    "learn = Learner(dls, model, cbs=CutMix1d(1.))\n",
    "learn.fit_one_cycle(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>nan</td>\n",
       "      <td>None</td>\n",
       "      <td>00:01</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Unlabeled variant: neither targets nor item tfms are passed to get_ts_dls,\n",
    "# and the learner uses MSELossFlat together with the CutMix1D alias\n",
    "dsid = 'NATOPS'\n",
    "X, y, splits = get_UCR_data(dsid, return_split=False)\n",
    "batch_tfms = TSStandardize()\n",
    "dls = get_ts_dls(X, splits=splits, batch_tfms=batch_tfms)\n",
    "model = build_model(ResNet, dls=dls)\n",
    "learn = Learner(dls, model, loss_func=MSELossFlat(), cbs=CutMix1D(1.))\n",
    "learn.fit_one_cycle(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "# create_scripts and beep are not defined in this notebook; they come from the star imports above\n",
    "out = create_scripts(); beep(out)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
